
fix: graceful fallback when attention backends fail to import #13060

Open
sym-bot wants to merge 2 commits into huggingface:main from sym-bot:fix/graceful-attention-fallback

Conversation


@sym-bot sym-bot commented Jan 31, 2026

Problem

External attention backends (flash_attn, xformers, sageattention, etc.) may be installed but fail to import at runtime due to ABI mismatches. For example, when flash_attn is compiled against PyTorch 2.4 but used with PyTorch 2.8, the import fails with:

OSError: .../flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEab

The current code uses importlib.util.find_spec() to check if packages exist, but this only verifies the package is installed—not that it can actually be imported. When the import fails, diffusers crashes instead of falling back to native PyTorch attention.
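
For illustration, a minimal sketch of the gap between the two checks, using flash_attn as the example package (any of the backends below behaves the same way):

```python
import importlib
import importlib.util

name = "flash_attn"

# find_spec() only proves the package is present on disk; it does not load the
# compiled extension, so an ABI mismatch stays hidden at this point.
if importlib.util.find_spec(name) is not None:
    try:
        importlib.import_module(name)  # the ABI error only surfaces here
        print(f"{name} imported successfully")
    except (ImportError, OSError) as e:
        print(f"{name} is installed but cannot be imported: {e}")
else:
    print(f"{name} is not installed")
```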

Solution

Wrap all external attention backend imports in try-except blocks that catch ImportError and OSError. On failure:

  1. Log a warning message explaining the issue
  2. Set the corresponding _CAN_USE_* flag to False
  3. Set the imported functions to None

This allows diffusers to gracefully degrade to PyTorch's native SDPA (scaled_dot_product_attention) instead of crashing.
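
A minimal sketch of this guarded-import pattern for one backend, assuming diffusers' logging helper; the flag and logger names mirror the PR description and review snippets, and the real module layout may differ:

```python
from diffusers.utils import logging

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func

    _CAN_USE_FLASH_ATTN = True
except (ImportError, OSError) as e:
    # 1. Log a warning explaining the issue.
    _flash_attn_logger = logging.get_logger(__name__)
    _flash_attn_logger.warning(
        f"flash_attn is installed but failed to import: {e}. Falling back to native PyTorch attention."
    )
    # 2. Mark the backend as unusable so dispatch skips it.
    _CAN_USE_FLASH_ATTN = False
    # 3. Null out the imported symbols.
    flash_attn_func = None
    flash_attn_varlen_func = None
```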

Affected backends

  • flash_attn (Flash Attention)
  • flash_attn_3 (Flash Attention 3)
  • aiter (AMD Instinct)
  • sageattention (SageAttention)
  • flex_attention (PyTorch Flex Attention)
  • torch_npu (Huawei NPU)
  • torch_xla (TPU/XLA)
  • xformers (Meta xFormers)

Testing

Tested with PyTorch 2.8.0 and flash_attn 2.7.4.post1 (compiled for PyTorch 2.4).

  • Before: from diffusers import ... crashes with an undefined symbol error
  • After: logs warning and uses native attention successfully

Example warning output

WARNING:diffusers.models.attention_dispatch:flash_attn is installed but failed to import: .../flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEab. Falling back to native PyTorch attention.
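
A rough way to reproduce this check in the mismatched environment (a flash_attn wheel built for an older PyTorch is assumed; DiffusionPipeline is just a convenient import target):

```python
# Importing the broken backend directly still fails...
try:
    import flash_attn  # noqa: F401
except (ImportError, OSError) as e:
    print(f"flash_attn import error: {e}")

# ...but with this change, importing diffusers no longer crashes: it logs the
# warning above and attention falls back to torch's scaled_dot_product_attention.
from diffusers import DiffusionPipeline  # noqa: F401

print("diffusers imported without raising")
```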

@DN6 DN6 left a comment

LGTM 👍🏽 Just some minor requests.

except (ImportError, OSError) as e:
    # Handle ABI mismatch or other import failures gracefully.
    # This can happen when flash_attn was compiled against a different PyTorch version.
    _flash_attn_logger = get_logger(__name__)

I think we can just add a single logger at the beginning of the file and reuse it instead of creating a dedicated one for each backend.
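
A small sketch of that suggestion (illustrative names, not the final diff):

```python
from diffusers.utils import logging

# Created once near the top of the module and shared by every backend's except-block.
logger = logging.get_logger(__name__)

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
except (ImportError, OSError) as e:
    logger.warning(
        f"flash_attn is installed but failed to import: {e}. Falling back to native PyTorch attention."
    )
    flash_attn_func = None
    flash_attn_varlen_func = None
```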

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
except (ImportError, OSError) as e:

Think we can include RuntimeError in the exceptions list as well.
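
Applied to the hunk above, only the exception tuple grows (a sketch of the requested change, not the final diff):

```python
try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
except (ImportError, OSError, RuntimeError) as e:
    # Same fallback handling as before; RuntimeError is caught as well because some
    # extensions can raise it during their own initialization.
    ...
```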

- Move logger to module level instead of creating per-backend loggers
- Add RuntimeError to exception list alongside ImportError and OSError

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>